{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp mtl_model.base\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MTLBase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from copy import copy\n",
    "from typing import Dict, Tuple\n",
    "\n",
    "import tensorflow as tf\n",
    "from m3tl.utils import dispatch_features, get_phase\n",
    "\n",
    "\n",
    "class MTLBase(tf.keras.Model):\n",
    "    def __init__(self, params, name:str, *args, **kwargs):\n",
    "        super(MTLBase, self).__init__(name, *args, **kwargs)\n",
    "        self.params = params\n",
    "        self.available_extract_target = copy(self.params.problem_list)\n",
    "        self.available_extract_target.append('all')\n",
    "        self.problem_list = self.params.problem_list\n",
    "\n",
    "    def extract_feature(self, extract_problem: str, feature_dict: dict, hidden_feature_dict: dict):\n",
    "\n",
    "        mode = get_phase()\n",
    "        if extract_problem not in self.available_extract_target:\n",
    "            raise ValueError('Tried to extract feature {0}, available extract problem: {1}'.format(\n",
    "                extract_problem, self.available_extract_target))\n",
    "        \n",
    "        # if key contains problem, return directly\n",
    "        if extract_problem in feature_dict and extract_problem in hidden_feature_dict:\n",
    "            return feature_dict[extract_problem], hidden_feature_dict[extract_problem]\n",
    "\n",
    "        # use dispatch function to extract record based on loss multiplier\n",
    "        if 'all' in feature_dict and 'all' in hidden_feature_dict:\n",
    "            return dispatch_features(\n",
    "                features=feature_dict['all'], hidden_feature=hidden_feature_dict['all'], \n",
    "                problem=extract_problem, mode=mode)\n",
    "        return dispatch_features(\n",
    "                features=feature_dict, hidden_feature=hidden_feature_dict, \n",
    "                problem=extract_problem, mode=mode)\n",
    "\n",
    "    def call(self, inputs: Tuple[Dict[str, tf.Tensor]]):\n",
    "        raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"MTLBase.extract_feature\" class=\"doc_header\"><code>MTLBase.extract_feature</code><a href=\"__main__.py#L15\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>MTLBase.extract_feature</code>(**`extract_problem`**:`str`, **`feature_dict`**:`dict`, **`hidden_feature_dict`**:`dict`)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(MTLBase.extract_feature)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract features(inputs) and hidden features(body model output tensors) from features and hidden_featues dicts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "Adding new problem weibo_cws, problem type: seq_tag\n",
      "Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "Adding new problem weibo_fake_cls, problem type: cls\n",
      "Adding new problem weibo_masklm, problem type: masklm\n",
      "Adding new problem weibo_pretrain, problem type: pretrain\n",
      "Adding new problem weibo_fake_regression, problem type: regression\n",
      "Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "Adding new problem weibo_premask_mlm, problem type: premask_mlm\n",
      "INFO:tensorflow:sampling weights: \n",
      "INFO:tensorflow:weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit: 0.2631578947368421\n",
      "INFO:tensorflow:weibo_fake_multi_cls: 0.2631578947368421\n",
      "INFO:tensorflow:weibo_masklm: 0.2236842105263158\n",
      "INFO:tensorflow:weibo_premask_mlm: 0.25\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "404 Client Error: Not Found for url: https://huggingface.co/voidful/albert_chinese_tiny/resolve/main/tf_model.h5\n",
      "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFAlbertModel: ['predictions.LayerNorm.weight', 'predictions.decoder.weight', 'predictions.dense.weight', 'predictions.decoder.bias', 'predictions.dense.bias', 'predictions.bias', 'predictions.LayerNorm.bias']\n",
      "- This IS expected if you are initializing TFAlbertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TFAlbertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of TFAlbertModel were initialized from the PyTorch model.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFAlbertModel for predictions without further training.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Modal Type id mapping: \n",
      " {\n",
      "    \"class\": 0,\n",
      "    \"image\": 1,\n",
      "    \"text\": 2\n",
      "}\n",
      "WARNING:tensorflow:AutoGraph could not transform <bound method Socket.send of <zmq.sugar.socket.Socket object at 0x7f00a53a4980>> and will run it as-is.\n",
      "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
      "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
      "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
      "WARNING: AutoGraph could not transform <bound method Socket.send of <zmq.sugar.socket.Socket object at 0x7f00a53a4980>> and will run it as-is.\n",
      "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
      "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
      "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n"
     ]
    }
   ],
   "source": [
    "from m3tl.test_base import TestBase\n",
    "import numpy as np\n",
    "\n",
    "tb = TestBase()\n",
    "\n",
    "features, hidden_features = tb.get_one_batch_body_model_output()\n",
    "\n",
    "mtl_base = MTLBase(params=tb.params, name='test_mtl_base')\n",
    "\n",
    "for problem in tb.params.problem_list:\n",
    "    loss_multiplier = mtl_base.extract_feature(problem, feature_dict=features, hidden_feature_dict=hidden_features)[0]['{}_loss_multiplier'.format(problem)].numpy()\n",
    "    assert np.min(loss_multiplier) == 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class BasicMTL(MTLBase):\n",
    "    def __init__(self, params, name: str, *args, **kwargs):\n",
    "        super().__init__(params, name, *args, **kwargs)\n",
    "    \n",
    "    def call(self, inputs: Tuple[Dict[str, tf.Tensor]]):\n",
    "        mode = get_phase()\n",
    "        features, hidden_features = inputs\n",
    "        features_per_problem, hidden_features_per_problem = {}, {}\n",
    "        for problem in self.available_extract_target:\n",
    "            features_per_problem[problem], hidden_features_per_problem[problem] = self.extract_feature(\n",
    "                extract_problem=problem, feature_dict=features, hidden_feature_dict=hidden_features\n",
    "            )\n",
    "        return features_per_problem, hidden_features_per_problem"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
